{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using TensorFlow backend.\n",
      "/home/leon/miniconda3/lib/python3.6/importlib/_bootstrap.py:219: RuntimeWarning: compiletime version 3.5 of module 'tensorflow.python.framework.fast_tensor_util' does not match runtime version 3.6\n",
      "  return f(*args, **kwds)\n"
     ]
    }
   ],
   "source": [
    "import numpy as np\n",
    "import json\n",
    "from keras.models import Model\n",
    "from keras.layers import Input, Activation\n",
    "from keras.layers.convolutional import Conv2D, SeparableConv2D, ZeroPadding2D, Cropping2D\n",
    "from keras.layers.pooling import AveragePooling2D\n",
    "from keras.layers.normalization import BatchNormalization\n",
    "from keras import backend as K\n",
    "from collections import OrderedDict"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "def format_decimal(arr, places=6):\n",
    "    \"\"\"Return a new list with every number in `arr` rounded to `places` decimals.\n",
    "\n",
    "    Each value is scaled by 10**places, passed through the built-in round(),\n",
    "    and divided by the same factor again.\n",
    "    \"\"\"\n",
    "    def _round_value(value):\n",
    "        scale = 10**places\n",
    "        return round(value * scale) / scale\n",
    "\n",
    "    return list(map(_round_value, arr))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixture registry: each test case stores its input, weights and expected\n",
    "# output under its own key; the whole mapping is serialized to JSON below.\n",
    "DATA = OrderedDict()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### pipeline 16"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed and input shape for this fixture.\n",
    "random_seed = 10016\n",
    "data_in_shape = (10, 10, 2)\n",
    "\n",
    "# Layers are applied in list order.\n",
    "layers = [\n",
    "    Conv2D(4, (3, 3), padding='valid', strides=(2, 2), use_bias=False),\n",
    "    BatchNormalization(epsilon=1e-03, axis=-1, center=True, scale=True),\n",
    "    Activation('relu'),\n",
    "    ZeroPadding2D(padding=((0, 1), (0, 1))),\n",
    "    Cropping2D(cropping=((1, 0), (1, 0))),\n",
    "    AveragePooling2D(pool_size=(1, 1), padding='valid', strides=(2, 2))\n",
    "]\n",
    "\n",
    "# Thread the input tensor through every layer in turn.\n",
    "input_layer = Input(shape=data_in_shape)\n",
    "output_layer = input_layer\n",
    "for layer in layers:\n",
    "    output_layer = layer(output_layer)\n",
    "model = Model(inputs=input_layer, outputs=output_layer)\n",
    "\n",
    "# Seed right before sampling so the input, drawn from [-1, 1), is reproducible.\n",
    "np.random.seed(random_seed)\n",
    "data_in = 2 * np.random.random(data_in_shape) - 1\n",
    "\n",
    "# Overwrite the initial weights with seeded random values, one seed per tensor.\n",
    "weights = []\n",
    "for i, w in enumerate(model.get_weights()):\n",
    "    np.random.seed(random_seed + i)\n",
    "    sample = np.random.random(w.shape)\n",
    "    # Tensor 4 has to stay positive (original note: 'std should be positive';\n",
    "    # presumably BatchNormalization's moving variance - confirm against Keras),\n",
    "    # so it keeps the raw [0, 1) sample; all others are rescaled to [-1, 1).\n",
    "    weights.append(sample if i == 4 else 2 * sample - 1)\n",
    "model.set_weights(weights)\n",
    "\n",
    "# Run a batch of one and keep the single prediction.\n",
    "result = model.predict(np.array([data_in]))\n",
    "prediction = result[0]\n",
    "data_out_shape = prediction.shape\n",
    "data_in_formatted = format_decimal(data_in.ravel().tolist())\n",
    "data_out_formatted = format_decimal(prediction.ravel().tolist())\n",
    "weights_formatted = [\n",
    "    {'data': format_decimal(w.ravel().tolist()), 'shape': w.shape}\n",
    "    for w in weights\n",
    "]\n",
    "\n",
    "DATA['pipeline_16'] = {\n",
    "    'input': {'data': data_in_formatted, 'shape': data_in_shape},\n",
    "    'weights': weights_formatted,\n",
    "    'expected': {'data': data_out_formatted, 'shape': data_out_shape}\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### export for Keras.js tests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Write the collected fixture to the JSON file named below.\n",
    "filename = '../../test/data/pipeline/16.json'\n",
    "# exist_ok=True creates the directory tree when missing and is a no-op when it\n",
    "# already exists, avoiding the race in a separate exists()-then-create check.\n",
    "os.makedirs(os.path.dirname(filename), exist_ok=True)\n",
    "with open(filename, 'w') as f:\n",
    "    json.dump(DATA, f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\"pipeline_16\": {\"input\": {\"data\": [-0.296519, 0.572188, -0.272211, -0.907259, -0.996198, 0.61739, 0.110548, 0.580901, -0.115354, 0.574469, -0.03939, 0.348826, -0.042372, 0.132598, -0.620505, 0.302978, -0.907868, 0.723517, -0.666202, -0.655146, -0.169661, -0.721866, 0.65103, -0.379465, 0.187194, 0.910984, -0.046393, -0.717226, -0.220449, -0.812018, 0.758602, -0.425723, 0.826973, 0.41357, -0.982014, 0.308827, 0.930297, 0.552791, 0.422404, -0.45334, 0.271448, 0.563136, -0.316318, 0.641678, 0.363059, -0.440587, -0.037204, -0.119982, 0.826628, -0.369548, 0.06503, 0.559157, 0.855733, 0.156159, 0.690997, 0.38689, 0.197399, -0.520678, -0.60102, -0.519463, 0.437547, -0.963725, 0.840374, -0.961812, 0.682657, 0.683939, 0.6797, 0.823475, -0.155763, 0.407472, 0.890151, -0.275848, -0.321692, -0.669862, 0.847674, 0.467442, -0.767249, -0.762515, 0.856726, -0.318873, 0.694542, 0.714746, -0.143808, -0.049621, -0.475187, -0.353159, 0.74349, 0.257662, -0.882949, -0.242672, -0.152523, -0.014676, 0.055601, 0.544025, -0.020051, -0.797163, -0.304542, 0.614325, -0.53998, 0.375662, -0.705527, -0.520286, -0.7861, -0.567627, -0.367321, 0.327227, 0.283708, 0.457965, 0.984347, -0.337644, -0.675635, -0.201958, -0.396271, 0.353544, 0.598736, -0.595224, 0.687218, 0.23673, 0.212353, 0.278087, 0.918215, -0.460093, 0.023297, -0.691769, -0.754752, -0.38847, 0.136175, 0.73544, -0.589597, -0.838913, 0.145886, 0.840253, -0.86358, 0.11517, -0.739268, -0.728574, -0.113248, -0.218071, -0.380281, -0.124292, 0.260543, -0.949342, 0.21152, 0.003548, 0.119453, 0.589717, 0.369072, 0.470988, 0.829205, 0.815702, -0.537322, -0.762229, 0.627759, 0.996532, 0.312333, -0.70115, -0.706727, 0.248581, -0.005799, 0.076031, -0.705038, 0.652309, -0.132503, -0.653827, -0.347404, -0.962842, -0.264451, 0.624884, -0.894047, -0.42713, 0.970572, 0.377402, -0.190215, 0.065945, -0.620644, 0.938643, 0.312769, 0.309219, 0.45483, 0.912505, 0.654109, 0.086082, -0.044841, -0.662166, 0.669878, -0.421224, -0.663752, 0.737268, 
-0.133045, 0.045642, 0.608814, 0.486256, 0.537827, 0.787655, -0.942509, -0.339053, -0.01045, -0.764766, -0.35555, -0.729608], \"shape\": [10, 10, 2]}, \"weights\": [{\"data\": [-0.296519, 0.572188, -0.272211, -0.907259, -0.996198, 0.61739, 0.110548, 0.580901, -0.115354, 0.574469, -0.03939, 0.348826, -0.042372, 0.132598, -0.620505, 0.302978, -0.907868, 0.723517, -0.666202, -0.655146, -0.169661, -0.721866, 0.65103, -0.379465, 0.187194, 0.910984, -0.046393, -0.717226, -0.220449, -0.812018, 0.758602, -0.425723, 0.826973, 0.41357, -0.982014, 0.308827, 0.930297, 0.552791, 0.422404, -0.45334, 0.271448, 0.563136, -0.316318, 0.641678, 0.363059, -0.440587, -0.037204, -0.119982, 0.826628, -0.369548, 0.06503, 0.559157, 0.855733, 0.156159, 0.690997, 0.38689, 0.197399, -0.520678, -0.60102, -0.519463, 0.437547, -0.963725, 0.840374, -0.961812, 0.682657, 0.683939, 0.6797, 0.823475, -0.155763, 0.407472, 0.890151, -0.275848], \"shape\": [3, 3, 2, 4]}, {\"data\": [-0.983721, 0.040472, -0.976147, -0.023768], \"shape\": [4]}, {\"data\": [0.524698, 0.474407, 0.471901, 0.222861], \"shape\": [4]}, {\"data\": [-0.040091, -0.549528, 0.347624, 0.057999], \"shape\": [4]}, {\"data\": [0.14783, 0.198852, 0.395548, 0.513733], \"shape\": [4]}], \"expected\": {\"data\": [0.261207, 0.531501, 4.084157, 0.351517, 0.429899, 0.814022, 3.722062, 0.208077, 0.0, 0.393933, 3.271786, 0.231966, 0.034173, 0.340105, 0.0, 0.254722], \"shape\": [2, 2, 4]}}}\n"
     ]
    }
   ],
   "source": [
    "print(json.dumps(DATA))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
